import os
import wx
import random
import time

from control_panel import *


# Whether a block has been uncovered by the player.
OPEN_STATE_NOT_OPEN, OPEN_STATE_OPENED = 0, 1

# Marker the player has put on a still-covered block (cycled with right-click).
BLOCK_FLAG_NONE, BLOCK_FLAG_REDFLAG, BLOCK_FLAG_QUESTION_MARK = range(3)

# Layout metrics (presumably pixels -- they are combined with client-rect sizes).
CONTROL_PANEL_HEIGHT = 40   # height handed to the control panel's upper area
MIN_BLOCK_HEIGHT = 2        # lower bound for the side length of one block


class MineBlock:
    """State of a single cell of the minesweeper board.

    ``open_state`` takes one of the OPEN_STATE_* values and ``block_flag`` one
    of the BLOCK_FLAG_* values.  ``has_mine`` and ``around_num`` start empty;
    the owning board sets them once mines have been placed.
    """

    def __init__(self, open_state, block_flag):
        self.has_mine = False        # True once a mine is placed in this cell
        self.around_num = 0          # number of mines in the neighbouring cells
        self.block_flag = block_flag
        self.open_state = open_state


class MinePanel(wx.Panel):
    """Minesweeper board widget: owns the mine grid, draws it and handles the mouse.

    The top of the panel is reserved for a ``ControlPanel`` (from the
    ``control_panel`` module); below it a grid of ``rows_num`` x ``cols_num``
    square blocks is drawn, centred horizontally.  Call ``reset_blocks`` to
    start a new game.

    Mouse controls implemented here:
      * left click            -- open a block; the first successful open starts the timer
      * right press           -- cycle the marker: none -> red flag -> "?" -> none
      * left+right chord, or
        left double-click     -- on an opened block, open all unflagged neighbours once
                                 at least ``around_num`` of them are flagged
    """

    # (dx, dy) offsets of the eight surrounding blocks.  The order matters:
    # chord-opening walks neighbours in this order and stops at the first mine.
    _NEIGHBOUR_OFFSETS = ((-1, -1), (0, -1), (1, -1),
                          (-1, 0),           (1, 0),
                          (-1, 1),  (0, 1),  (1, 1))

    def __init__(self, parent, sz, st):
        """Create the panel.

        Args:
            parent: parent wx window.
            sz: initial size, forwarded to ``wx.Panel`` as ``size``.
            st: window style, forwarded to ``wx.Panel`` as ``style``.
        """
        super().__init__(parent, size=sz, style=st)

        # Game state.  Initialised here (not only in reset_blocks) so that
        # size/paint/mouse events delivered before the first reset_blocks()
        # call see an empty, consistent board instead of raising AttributeError.
        self.rows_num       = 0
        self.cols_num       = 0
        self.mines_num      = 0
        self.blocks_map     = []
        self.die            = False
        self.success        = False
        self.start_time     = 0
        self.last_time      = 0

        self.Bind(wx.EVT_PAINT, self._on_paint)
        self.Bind(wx.EVT_SIZE, self._on_size)
        self.Bind(wx.EVT_LEFT_DOWN, self._on_mouse_left_down)
        self.Bind(wx.EVT_LEFT_UP, self._on_mouse_left_up)
        self.Bind(wx.EVT_LEFT_DCLICK, self._on_mouse_left_dbclick)
        self.Bind(wx.EVT_RIGHT_DOWN, self._on_mouse_right_down)
        self.Bind(wx.EVT_RIGHT_UP, self._on_mouse_right_up)
        self.timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self._on_timer, self.timer)

        # All painting goes through _on_paint (buffered), so no default erase.
        self.SetBackgroundStyle(wx.BG_STYLE_PAINT)
        # Text colour for a neighbouring-mine count; index = count (0 is never drawn).
        self.num_colors     = [None, wx.Colour(0, 0, 230), wx.Colour(0, 180, 0), wx.Colour(210, 0, 0), wx.Colour(220, 50, 180),
                               wx.Colour(30, 200, 200), wx.Colour(240, 240, 30), wx.Colour(160, 80, 180), wx.Colour(0, 0, 0)]
        self.control_panel  = ControlPanel(self)
        self.block_height   = 0
        self.red_flag_count = 0
        self.num_txt_font   = None
        self.left_down      = False
        self.right_down     = False
        self.red_flag_icon  = None
        self.dc_buffer      = None
        # Optional bitmap; a vector flag is drawn instead when it is missing.
        red_flag_icon_file  = os.path.join('res', 'red_flag.png')
        if os.path.exists(red_flag_icon_file):
            self.red_flag_icon = wx.Bitmap(red_flag_icon_file)

    # ------------------------------------------------------------------ game setup

    def reset_with_current_mode(self):
        """Restart the game with the current board size and mine count, then repaint."""
        self.reset_blocks(self.rows_num, self.cols_num, self.mines_num)
        self.Refresh()

    def reset_blocks(self, rows_num, cols_num, mines_num):
        """Start a new game on a rows_num x cols_num board holding mines_num mines.

        Negative sizes are treated as 0, and mines_num is clamped to
        [0, rows_num * cols_num].  Without the clamp, asking for more mines
        than blocks would make placement loop forever.  Mines go on distinct
        blocks chosen uniformly at random.  The caller is responsible for
        repainting.
        """
        rows_num = max(rows_num, 0)
        cols_num = max(cols_num, 0)
        blocks_num = rows_num * cols_num
        mines_num = min(max(mines_num, 0), blocks_num)

        # A game restarted mid-play must not keep ticking.
        self.timer.Stop()
        self.die            = False
        self.success        = False
        self.start_time     = 0
        self.last_time      = 0
        self.rows_num       = rows_num
        self.cols_num       = cols_num
        self.mines_num      = mines_num
        self.blocks_map     = []
        self.block_height   = 0
        self.red_flag_count = 0
        self.left_down      = False
        self.right_down     = False
        self._init_important_data()

        self.blocks_map = [MineBlock(OPEN_STATE_NOT_OPEN, BLOCK_FLAG_NONE) for _ in range(blocks_num)]
        for locate in random.sample(range(blocks_num), mines_num):
            self.blocks_map[locate].has_mine = True
        self._calc_around_nums()

    def _neighbour_points(self, x_index, y_index):
        """Return the eight (x, y) index pairs around a block, in _NEIGHBOUR_OFFSETS order.

        Points may lie off the board; callers filter them through _get_block.
        """
        return [(x_index + dx, y_index + dy) for (dx, dy) in self._NEIGHBOUR_OFFSETS]

    def _calc_around_nums(self):
        """Store in every block the number of mines among its (up to) eight neighbours."""
        for i, block in enumerate(self.blocks_map):
            (x_index, y_index) = (i % self.cols_num, i // self.cols_num)
            block.around_num = sum(self._get_mine_num_in_block(nx, ny)
                                   for (nx, ny) in self._neighbour_points(x_index, y_index))

    def _get_mine_num_in_block(self, x_index, y_index):
        """Return 1 if (x_index, y_index) is on the board and holds a mine, else 0."""
        block = self._get_block(x_index, y_index)
        return 1 if block is not None and block.has_mine else 0

    # ------------------------------------------------------------------ layout

    def _init_control_panel_important_data(self, client_rect, residue_count, consume_time):
        """Push layout size, remaining-mine count and elapsed seconds to the control panel (no repaint)."""
        self.control_panel.set_upper_height(CONTROL_PANEL_HEIGHT, client_rect)
        self.control_panel.set_residue_mines_count(residue_count, False)
        self.control_panel.set_consume_time(consume_time, False)

    def _calc_block_height(self, client_rect):
        """Return the side length of one square block for the given client area.

        The height left below the control panel is shared among the rows.  The
        width is shared among the columns plus two extra block widths.  The
        result is never below MIN_BLOCK_HEIGHT.
        """
        rows = max(self.rows_num, 1)   # avoid division by zero before the first reset
        h1 = (client_rect.GetHeight() - self.control_panel.occupy_height()) // rows
        h2 = client_rect.GetWidth() // (self.cols_num + 2)
        return max(min(h1, h2), MIN_BLOCK_HEIGHT)

    def _init_important_data(self):
        """Recompute size-dependent state: control-panel data, block size and number font."""
        padding_in_block = 4
        min_font_pixel_height = 6
        rect = self.GetClientRect()
        consume_time = int(self.last_time - self.start_time)
        self._init_control_panel_important_data(rect, self.mines_num - self.red_flag_count, consume_time)

        self.block_height = self._calc_block_height(rect)
        self.num_txt_font = wx.Font(10, wx.FONTFAMILY_ROMAN, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_BOLD)
        font_pixel_height = max(self.block_height - padding_in_block, min_font_pixel_height)
        self.num_txt_font.SetPixelSize(wx.Size(0, font_pixel_height))

    def _board_origin(self, client_rect=None):
        """Return (base_x, base_y), the pixel position of the top-left block.

        The grid is centred horizontally, never starting left of x = 0.  It is
        placed at the control panel's ``upper_panel_height``.
        """
        if client_rect is None:
            client_rect = self.GetClientRect()
        base_x = max((client_rect.GetWidth() - self.cols_num * self.block_height) // 2, 0)
        return (base_x, self.control_panel.upper_panel_height)

    def _adjust_dc_buffer(self):
        """(Re)create the off-screen bitmap used for buffered painting.

        Each dimension is clamped to at least 1 px.  The client area can
        collapse to 0 (e.g. when the top-level window is minimised on some
        platforms), and wx rejects zero-sized bitmaps.
        """
        rect = self.GetClientRect()
        self.dc_buffer = wx.Bitmap(max(rect.GetWidth(), 1), max(rect.GetHeight(), 1))

    def _on_size(self, event):
        """Resize handler: recompute layout and back buffer, then repaint everything."""
        self._init_important_data()
        self._adjust_dc_buffer()
        self.Refresh()

    # ------------------------------------------------------------------ painting

    def _on_paint(self, event):
        """Paint handler: draw the control panel and the blocks in the update region."""
        if self.dc_buffer is None:
            self._adjust_dc_buffer()
        dc = wx.BufferedPaintDC(self, self.dc_buffer)
        dc.Clear()
        rect = self.GetClientRect()
        dc.SetPen(wx.BLACK_PEN)

        self.control_panel.draw(dc, rect)

        if self.num_txt_font is not None:
            dc.SetFont(self.num_txt_font)

        (base_x, base_y) = self._board_origin(rect)
        self._draw_all_blocks(dc, self.block_height, base_x, base_y)

    def _get_one_block_rect(self, x_index, y_index):
        """Return the wx.Rect (client coordinates) covered by block (x_index, y_index)."""
        (base_x, base_y) = self._board_origin()
        h = self.block_height
        return wx.Rect(base_x + x_index * h, base_y + y_index * h, h, h)

    def _need_redraw_block_rect(self, x, y, width, height):
        """Return True if the given rectangle intersects the current paint update region."""
        target = wx.Rect(x, y, width, height)
        upd = wx.RegionIterator(self.GetUpdateRegion())
        while upd.HaveRects():
            if upd.GetRect().Intersects(target):
                return True
            upd.Next()
        return False

    def _draw_all_blocks(self, dc, block_height, base_x, base_y):
        """Draw every block that lies in the update region.

        Closed blocks are green and show their marker.  Opened blocks are gray
        and show their neighbouring-mine count.  After a loss (self.die), hidden
        mines are revealed, wrong flags are crossed out, and the opened mine is
        drawn in red with a cross.
        """
        if self.rows_num <= 0:
            return
        green_brush = wx.Brush(wx.Colour(0, 255, 0))
        gray_brush  = wx.Brush(wx.Colour(200, 200, 200))
        h = block_height
        for i in range(self.rows_num):
            y = base_y + i * h
            for j in range(self.cols_num):
                x = base_x + j * h
                if not self._need_redraw_block_rect(x, y, h, h):
                    continue
                block = self.blocks_map[i * self.cols_num + j]
                flagged = block.block_flag == BLOCK_FLAG_REDFLAG

                if block.open_state == OPEN_STATE_NOT_OPEN:
                    dc.SetBrush(green_brush)
                    dc.DrawRectangle(x, y, h, h)
                    if self.die:
                        # Game over: reveal hidden mines, cross out flags placed on safe blocks.
                        if block.has_mine:
                            self._draw_mine_icon(dc, x, y, h, h, False)
                            if flagged:
                                self._draw_red_flag(dc, x, y, h, h)
                        elif flagged:
                            self._draw_red_flag(dc, x, y, h, h)
                            self._draw_wrong_icon(dc, x, y, h, h)
                    elif flagged:
                        self._draw_red_flag(dc, x, y, h, h)
                    elif block.block_flag == BLOCK_FLAG_QUESTION_MARK:
                        self._draw_question_mark(dc, x, y, h, h)
                else:
                    dc.SetBrush(gray_brush)
                    dc.DrawRectangle(x, y, h, h)
                    if self.die and block.has_mine:
                        # An opened mine: the one that ended the game.
                        self._draw_mine_icon(dc, x, y, h, h, True)
                        self._draw_wrong_icon(dc, x, y, h, h)
                    else:
                        # _draw_mine_number draws nothing for a count of 0.
                        self._draw_mine_number(dc, block.around_num, x, y, h, h)

    def _draw_mine_icon(self, dc, x, y, width, height, is_wrong_mine):
        """Draw a mine as a circle outline, red when is_wrong_mine else black; restores the pen."""
        radius = max(width // 2 - 2, 4)
        old_pen = dc.GetPen()
        if is_wrong_mine:
            dc.SetPen(wx.Pen(wx.Colour(255, 0, 0)))
        else:
            dc.SetPen(wx.Pen(wx.Colour(0, 0, 0)))
        dc.DrawCircle(x + width // 2, y + height // 2, radius)
        dc.SetPen(old_pen)

    def _draw_wrong_icon(self, dc, x, y, width, height):
        """Draw a thick red X across the block; restores the pen."""
        old_pen = dc.GetPen()
        dc.SetPen(wx.Pen(wx.Colour(255, 0, 0), width=3))
        dc.DrawLine(x, y, x + width, y + height)
        dc.DrawLine(x + width, y, x, y + height)
        dc.SetPen(old_pen)

    def _draw_centered_text(self, dc, text, colour, x, y, width, height):
        """Draw text in colour, centred on each axis where it fits; restores the text colour."""
        old_color = dc.GetTextForeground()
        dc.SetTextForeground(colour)
        sz = dc.GetTextExtent(text)
        if width > sz.GetWidth():
            x += (width - sz.GetWidth()) // 2
        if height > sz.GetHeight():
            y += (height - sz.GetHeight()) // 2
        dc.DrawText(text, x, y)
        dc.SetTextForeground(old_color)

    def _draw_mine_number(self, dc, num, x, y, width, height):
        """Draw a neighbouring-mine count (1..8) in its colour; other values draw nothing."""
        if not 1 <= num <= 8:
            return
        self._draw_centered_text(dc, str(num), self.num_colors[num], x, y, width, height)

    def _draw_red_flag(self, dc, x, y, width, height):
        """Draw the red-flag bitmap centred in the block, or a simple vector flag without it."""
        if self.red_flag_icon is not None:
            icon_w = self.red_flag_icon.GetWidth()
            icon_h = self.red_flag_icon.GetHeight()
            # Centre on each axis independently; pin to the block corner when the icon is larger.
            icon_x = x + (width - icon_w) // 2 if width > icon_w else x
            icon_y = y + (height - icon_h) // 2 if height > icon_h else y
            dc.DrawBitmap(self.red_flag_icon, icon_x, icon_y)
        else:
            old_pen = dc.GetPen()
            dc.SetPen(wx.Pen(wx.Colour(255, 0, 0)))
            dc.DrawLine(x+2, y+2, x+2, y+height-4)
            dc.DrawLine(x+2, y+2, x+width-4, y+height//2)
            dc.DrawLine(x+2, y+height//2, x+width-4, y+height//2)
            dc.SetPen(old_pen)

    def _draw_question_mark(self, dc, x, y, width, height):
        """Draw a black "?" centred in the block."""
        self._draw_centered_text(dc, "?", wx.Colour(0, 0, 0), x, y, width, height)

    # ------------------------------------------------------------------ mouse / timer

    def _refresh_area(self, need_refresh, update_rect):
        """Repaint update_rect, or the whole panel when it is None; no-op unless need_refresh."""
        if not need_refresh:
            return
        if update_rect is not None:
            self.RefreshRect(update_rect)
        else:
            self.Refresh()

    def _on_mouse_left_down(self, event):
        """Record the left-button press and forward it to the control panel."""
        self.left_down = True
        self.control_panel.handle_mouse_left_down(event)

    def _quick_open(self, x, y):
        """Chord at pixel (x, y): open the unflagged neighbours of an opened block."""
        (x_index, y_index) = self._calc_x_y_index(x, y)
        block = self._get_block(x_index, y_index)
        if block is None or block.open_state != OPEN_STATE_OPENED:
            return
        self._refresh_area(*self._try_to_open_neighbours_without_redflag(x, y))
        self._check_success()

    def _on_mouse_left_up(self, event):
        """Left release: chord if the right button is held, otherwise open the clicked block."""
        self.left_down = False
        # NOTE(review): presumably True when the click was consumed by a
        # control-panel widget -- see control_panel.ControlPanel.
        if self.control_panel.handle_mouse_left_up(event):
            return
        if self.die or self.success:
            return
        (v_x, v_y) = event.GetPosition()
        if self.right_down:
            self._quick_open(v_x, v_y)
            return
        (need_refresh, update_rect) = self._try_to_open_a_block(v_x, v_y)
        self._refresh_area(need_refresh, update_rect)
        # The clock starts with the first block that opens without hitting a mine.
        if need_refresh and not self.die and self.start_time == 0:
            self.start_time = time.time()
            self.last_time = self.start_time
            self.timer.Start(milliseconds=1000)
        self._check_success()

    def _on_mouse_left_dbclick(self, event):
        """Left double-click: chord on the clicked block (only if it is opened)."""
        if self.die or self.success:
            return
        (v_x, v_y) = event.GetPosition()
        self._refresh_area(*self._try_to_open_neighbours_without_redflag(v_x, v_y))
        self._check_success()

    def _check_success(self):
        """Declare a win when exactly mines_num blocks remain closed: stop the timer and tell the player."""
        if self.die or self.success or not self.blocks_map:
            return
        closed = sum(1 for blk in self.blocks_map if blk.open_state == OPEN_STATE_NOT_OPEN)
        if closed == self.mines_num:
            self.success = True
            self.timer.Stop()
            wx.MessageBox("Success", "mine sweeper")

    def _on_mouse_right_down(self, event):
        """Right press: cycle the marker on the block under the cursor."""
        self.right_down = True
        if self.die or self.success:
            return
        (v_x, v_y) = event.GetPosition()
        self._refresh_area(*self._try_to_flag_a_block(v_x, v_y))

    def _on_mouse_right_up(self, event):
        """Right release: chord if the left button is still held."""
        self.right_down = False
        if self.die or self.success:
            return
        (v_x, v_y) = event.GetPosition()
        if self.left_down:
            self._quick_open(v_x, v_y)

    def _on_timer(self, event):
        """Once per second while a game runs: push elapsed whole seconds to the control panel."""
        if self.die or self.start_time == 0:
            return
        self.last_time = time.time()
        consume_time = int(self.last_time - self.start_time)
        self.control_panel.set_consume_time(consume_time, True)

    # ------------------------------------------------------------------ board logic

    def _calc_x_y_index(self, x, y):
        """Map pixel (x, y) to block indices; the result may be off the board (e.g. negative)."""
        block_height = self.block_height
        if block_height <= 0:
            # Layout not computed yet: return a position no block can have.
            return (-1, -1)
        (base_x, base_y) = self._board_origin()
        return ((x - base_x) // block_height, (y - base_y) // block_height)

    def _try_to_open_a_block(self, x, y):
        """Open the block at pixel (x, y).

        Returns (need_refresh, update_rect).  update_rect is None when the
        whole panel must be repainted (after hitting a mine).
        """
        if self.rows_num <= 0:
            return (False, None)
        (x_index, y_index) = self._calc_x_y_index(x, y)
        (blk, update_rect) = self._set_block_opened(x_index, y_index)
        if self.die:
            return (True, None)
        if blk is not None:
            return (True, update_rect)
        return (False, None)

    def _get_block(self, x_index, y_index):
        """Return the MineBlock at (x_index, y_index), or None when off the board."""
        if 0 <= y_index < self.rows_num and 0 <= x_index < self.cols_num:
            return self.blocks_map[y_index * self.cols_num + x_index]
        return None

    @staticmethod
    def _union_rect(acc, rect):
        """Return the bounding box of acc and rect, treating None as empty.

        The result of wx.Rect.Union is used explicitly, so this works whether
        or not Union also mutates its receiver.
        """
        if acc is None:
            return rect
        if rect is None:
            return acc
        return acc.Union(rect)

    def _try_to_open_neighbours_without_redflag(self, x, y):
        """Chord on the opened block at pixel (x, y).

        If at least ``around_num`` neighbours are red-flagged (or are opened
        mines), every unflagged closed neighbour is opened, in
        _NEIGHBOUR_OFFSETS order, stopping at the first mine.  Closed blocks
        are ignored: chording them would expose their hidden count.

        Returns (need_refresh, update_rect); update_rect is None after a mine hit.
        """
        (x_index, y_index) = self._calc_x_y_index(x, y)
        block = self._get_block(x_index, y_index)
        if block is None or block.open_state != OPEN_STATE_OPENED:
            return (False, None)
        neighbours = self._neighbour_points(x_index, y_index)

        seen_mine_around = 0
        for (nx, ny) in neighbours:
            blk = self._get_block(nx, ny)
            if blk is None:
                continue
            if blk.open_state == OPEN_STATE_OPENED:
                if blk.has_mine:
                    seen_mine_around += 1
            elif blk.block_flag == BLOCK_FLAG_REDFLAG:
                seen_mine_around += 1
        if seen_mine_around < block.around_num:
            return (False, None)

        opened_count = 0
        nbs_rect = None
        for (nx, ny) in neighbours:
            (blk, update_rect) = self._set_block_opened(nx, ny)
            if self.die:
                return (True, None)
            if blk is not None:
                opened_count += 1
                nbs_rect = self._union_rect(nbs_rect, update_rect)

        if opened_count > 0:
            return (True, nbs_rect)
        return (False, None)

    def _open_single_block(self, x_index, y_index):
        """Open one block without cascading.

        Only closed, non-red-flagged, on-board blocks are opened; opening a
        mine ends the game (die = True, timer stopped).

        Returns (block, rect) for an opened block, else (None, None).
        """
        if self.die:
            return (None, None)
        block = self._get_block(x_index, y_index)
        if block is None or block.open_state != OPEN_STATE_NOT_OPEN or block.block_flag == BLOCK_FLAG_REDFLAG:
            return (None, None)
        block.open_state = OPEN_STATE_OPENED
        if block.has_mine:
            print("!!!!! DIE !!!!!!" + "   x_index=" + str(x_index) + ", y_index=" + str(y_index))
            self.die = True
            self.timer.Stop()
        return (block, self._get_one_block_rect(x_index, y_index))

    def _recursive_open_all_neighbour_zero_around(self, x_index, y_index):
        """Flood-open everything reachable from a zero-count block at (x_index, y_index).

        Despite the historical name, this is iterative (explicit stack).  A
        recursive version needs two Python frames per opened cell and can raise
        RecursionError on large empty areas.  Neighbours of every zero-count
        block are opened; newly opened zero-count blocks are expanded in turn.

        Returns the bounding wx.Rect of the newly opened blocks, or None if
        nothing was opened (or the game is over).
        """
        if self.die:
            return None
        blocks_rect = None
        pending = [(x_index, y_index)]
        while pending:
            (cx, cy) = pending.pop()
            for (nx, ny) in self._neighbour_points(cx, cy):
                (blk, update_rect) = self._open_single_block(nx, ny)
                if self.die:
                    # Unreachable while around_num is consistent with has_mine; kept as a guard.
                    return None
                if blk is None:
                    continue
                blocks_rect = self._union_rect(blocks_rect, update_rect)
                if blk.around_num == 0:
                    pending.append((nx, ny))
        return blocks_rect

    def _set_block_opened(self, x_index, y_index):
        """Open the block at (x_index, y_index), cascading through zero-count areas.

        Returns (block, update_rect) when the block was opened, where
        update_rect covers everything opened.  Returns (None, None) when
        nothing was opened: off board, already open, red-flagged, or game over.
        """
        (block, update_rect) = self._open_single_block(x_index, y_index)
        if block is not None and not self.die and block.around_num == 0:
            flood_rect = self._recursive_open_all_neighbour_zero_around(x_index, y_index)
            update_rect = self._union_rect(update_rect, flood_rect)
        return (block, update_rect)

    def _try_to_flag_a_block(self, x, y):
        """Cycle the marker on the closed block at pixel (x, y): none -> red flag -> "?" -> none.

        Keeps red_flag_count and the control panel's remaining-mine count in sync.
        Returns (need_refresh, block_rect).
        """
        if self.rows_num <= 0:
            return (False, None)
        (x_index, y_index) = self._calc_x_y_index(x, y)
        block = self._get_block(x_index, y_index)
        if block is None or block.open_state == OPEN_STATE_OPENED:
            return (False, None)

        if block.block_flag == BLOCK_FLAG_NONE:
            block.block_flag = BLOCK_FLAG_REDFLAG
            self.red_flag_count += 1
        elif block.block_flag == BLOCK_FLAG_REDFLAG:
            block.block_flag = BLOCK_FLAG_QUESTION_MARK
            self.red_flag_count -= 1
        elif block.block_flag == BLOCK_FLAG_QUESTION_MARK:
            # The flag count is unchanged, so the control panel needs no update.
            block.block_flag = BLOCK_FLAG_NONE
            return (True, self._get_one_block_rect(x_index, y_index))
        else:
            return (False, None)

        self.control_panel.set_residue_mines_count(self.mines_num - self.red_flag_count, True)
        return (True, self._get_one_block_rect(x_index, y_index))
